/**
 * 
 */
package edu.umd.clip.lm.tools;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Vector;
import java.util.Map;
import java.util.HashSet;
import java.util.BitSet;
import java.util.Iterator;

import org.w3c.dom.*;
import javax.xml.parsers.*;
import org.xml.sax.*;
import  org.w3c.dom.bootstrap.DOMImplementationRegistry;
import  org.w3c.dom.Document;
import  org.w3c.dom.ls.*;

import java.io.*;
import java.nio.charset.Charset;

import edu.umd.clip.lm.factors.*;
import edu.umd.clip.lm.util.tree.*;
import edu.umd.clip.lm.util.algo.BrownAlgorithm;
import edu.umd.clip.lm.model.*;
import edu.umd.clip.lm.util.*;

import edu.berkeley.nlp.util.*;
/**
 * Command-line tool that builds a tree over hidden-factor tuples.
 *
 * <p>The tool reads a factor description XML file, enumerates the hidden
 * factor tuples into an initial {@link AnyaryTree}, collects bigram counts of
 * hidden-factor tuples from the training data, runs {@link BrownAlgorithm}
 * over the normalized counts and converts the result into an {@link HFT}.
 * Along the way it writes {@code tree-before.dot}, {@code tree-after.dot} and
 * {@code hft.xml} into the current directory and saves the experiment config.
 *
 * @author Denis Filimonov <den@cs.umd.edu>
 *
 */
public class HFTTrainer {

	/** Command-line options understood by {@link HFTTrainer#main(String[])}. */
	public static class Options {
        @Option(name = "-input", required = false, usage = "Training data file (Default: stdin)")
		public String input;
        @Option(name = "-config", usage = "XML config file")
		public String config;
        @Option(name = "-factorDesc", usage = "Factor description XML file")
		public String factorDesc;
	}

	/**
	 * Entry point: trains the tree and writes all output files.
	 *
	 * @param args command-line arguments, see {@link Options}
	 */
	public static void main(String[] args) {
        OptionParser optParser = new OptionParser(Options.class);
        Options opts = (Options) optParser.parse(args, true);

		// new File(null) would otherwise fail with a bare NullPointerException
		if (opts.factorDesc == null) {
			System.err.println("Missing required option: -factorDesc");
			return;
		}

		DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
		DocumentBuilder db = null;
		try {
			db = dbf.newDocumentBuilder();
		} catch(ParserConfigurationException pce) {
			pce.printStackTrace();
			return;
		}

		Document doc = null;
		try {
			doc = db.parse(new File(opts.factorDesc));
		} catch (SAXException e1) {
			e1.printStackTrace();
		} catch (IOException e1) {
			e1.printStackTrace();
		}
		if (doc == null) {
			// the factor description could not be parsed; the error is already reported
			return;
		}
		Element rootElement = doc.getDocumentElement();
		FactorTupleDescription tupleDesc = FactorTupleDescription.fromXML(rootElement);

		AnyaryTree<FactorTuple> root = buildInitialTree(tupleDesc);

		try {
			FileWriter treeOutput = new FileWriter("tree-before.dot");
			try {
				GraphvizOutput graph = new GraphvizOutput("HFT");
				treeOutput.write(graph.getOutput(root));
			} finally {
				// close (not just flush) so the file handle is released even on failure
				treeOutput.close();
			}
		} catch(IOException e) {
			e.printStackTrace();
		}

		// now collect the statistics
		// TODO: LM order is assumed to be 3
		FLMInputParser parser = new FLMInputParser(tupleDesc);
		long[] startTuples = new long[1];
		startTuples[0] = tupleDesc.createStartTuple();
		parser.setStartTuples(startTuples);
		long[] endTuples = new long[1];
		endTuples[0] = tupleDesc.createEndTuple();
		parser.setEndTuples(endTuples);

		HashMap<FactorTuple, HashMap<FactorTuple, Double>> counts = collectBigramCounts(parser, opts.input);

		// every observed bigram contributed exactly 1.0, so the total is the sum of all counts
		double totalCounts = 0.0;
		for(HashMap<FactorTuple, Double> successors : counts.values()) {
			for(Double c : successors.values()) {
				totalCounts += c.doubleValue();
			}
		}

		// the Brown algorithm
		BrownAlgorithm<FactorTuple> algo = new BrownAlgorithm<FactorTuple>(root);

		for(Map.Entry<FactorTuple, HashMap<FactorTuple, Double>> entry : counts.entrySet()) {
			// normalize by the total number of bigrams (joint, not conditional, frequencies)
			for(Map.Entry<FactorTuple, Double> entry2 : entry.getValue().entrySet()) {
				entry2.setValue(Double.valueOf(entry2.getValue().doubleValue() / totalCounts));
			}
			algo.setDistribution(entry.getKey(), entry.getValue());
		}

		AnyaryTree<FactorTuple> result = algo.process();

		try {
			FileWriter treeOutput = new FileWriter("tree-after.dot");
			try {
				GraphvizOutput graph = new GraphvizOutput("HFT") {
					protected <T> String getText(T t) {
						return ((FactorTuple) t).toStringNoNull();
					}
				};
				treeOutput.write(graph.getOutput(result));
			} finally {
				treeOutput.close();
			}
		} catch(IOException e) {
			e.printStackTrace();
		}

		Experiment.initialize(opts.config);
		Experiment experiment = Experiment.getInstance();
		experiment.setTupleDescription(tupleDesc);
		experiment.setTupleParser(parser);

		HFT<HFTPayload> hft = new HFT<HFTPayload>(makeHFT(new BinaryPrefix(new BitSet(0), 0), result));
		experiment.setHFT(hft);
		try {
			experiment.saveConfig(opts.config);
		} catch(IOException e) {
			e.printStackTrace();
		}
		Document hftDoc = db.getDOMImplementation().createDocument(null, null, null);
		Element elem = hft.createXML(hftDoc);
		hftDoc.appendChild(elem);
		writeHftXml(hftDoc);
	}

	/**
	 * Builds the initial tree: the root's children enumerate the values of the
	 * hidden factors that have no parent, and each of those children is then
	 * expanded with the hidden factors that do have a parent.
	 *
	 * @param tupleDesc the factor tuple description to enumerate
	 * @return the root of the populated tree
	 */
	// TODO: rewrite it in a generic manner to support arbitrary depth of dependent factors
	private static AnyaryTree<FactorTuple> buildInitialTree(FactorTupleDescription tupleDesc) {
		byte[] hiddenFactors = tupleDesc.getHiddenFactors();
		FactorDescription[] descs = tupleDesc.getDescriptions();

		AnyaryTree<FactorTuple> root = new AnyaryTree<FactorTuple>(null);

		byte[] independentFactors = new byte[hiddenFactors.length];
		byte[] dependentFactors = new byte[hiddenFactors.length];
		int independentCount = 0;
		int dependentCount = 0;
		// NOTE(review): assumes getHiddenFactors() returns factor indices (the same
		// kind of index iterateThroughSpace passes to getDescription()), so the
		// description is looked up by the factor index rather than by the position
		// inside hiddenFactors. The two coincide only when the hidden factors
		// occupy indices 0..n-1.
		for(byte factor : hiddenFactors) {
			if (descs[factor].getParent() == null) {
				independentFactors[independentCount++] = factor;
			} else {
				dependentFactors[dependentCount++] = factor;
			}
		}
		independentFactors = Arrays.copyOf(independentFactors, independentCount);
		dependentFactors = Arrays.copyOf(dependentFactors, dependentCount);

		iterateThroughSpace(root, tupleDesc, independentFactors, 0, FactorTuple.getValues(tupleDesc.createTuple()));

		if (dependentFactors.length > 0) {
			// there are dependent factors
			HashSet<AnyaryTree<FactorTuple>> children = root.getChildren();
			if (children != null) {
				for(AnyaryTree<FactorTuple> child : children) {
					// each child gets its own copy of the values so siblings do not interfere
					iterateThroughSpace(child, tupleDesc, dependentFactors, 0, child.getPayload().getValues().clone());
				}
			}
		}
		return root;
	}

	/**
	 * Reads the training data line by line and counts adjacent pairs of tuples
	 * after masking each tuple with {@code FactorTuple.getHiddenMask()}.
	 *
	 * <p>An {@link IOException} is reported on stderr and the counts collected
	 * up to that point are returned.
	 *
	 * @param parser    parser used to turn a line into a sentence of tuples
	 * @param inputPath training data file, or {@code null} to read from stdin
	 * @return for every first tuple, the counts of the tuples that follow it
	 */
	private static HashMap<FactorTuple, HashMap<FactorTuple, Double>> collectBigramCounts(
			FLMInputParser parser,
			String inputPath)
	{
		HashMap<FactorTuple, HashMap<FactorTuple, Double>> counts = new HashMap<FactorTuple, HashMap<FactorTuple, Double>>(100);
		try {
			// only a stream opened here is closed here; System.in is left alone
			FileInputStream fileInput = inputPath == null ? null : new FileInputStream(inputPath);
			try {
				InputStreamReader reader = new InputStreamReader(fileInput == null ? System.in : fileInput,
						Charset.forName("UTF-8"));
				BufferedReader inputData = new BufferedReader(reader);

				for(String line = inputData.readLine(); line != null; line = inputData.readLine()) {
					long[] sentence = parser.parseSentence(line);
					// clear overt factors
					for(int i=0; i<sentence.length; ++i) {
						sentence[i] &= FactorTuple.getHiddenMask();
					}

					for(int i=0; i<sentence.length-1; ++i) {
						FactorTuple tuple1 = new FactorTuple(sentence[i]);
						FactorTuple tuple2 = new FactorTuple(sentence[i+1]);
						HashMap<FactorTuple, Double> map = counts.get(tuple1);
						if (map == null) {
							map = new HashMap<FactorTuple, Double>();
							counts.put(tuple1, map);
						}
						Double count = map.get(tuple2);
						if (count == null) {
							map.put(tuple2, Double.valueOf(1.0));
						} else {
							map.put(tuple2, Double.valueOf(count.doubleValue() + 1.0));
						}
					}
				}
			} finally {
				if (fileInput != null) {
					fileInput.close();
				}
			}
		} catch(IOException e) {
			e.printStackTrace();
		}
		return counts;
	}

	/**
	 * Serializes the document element of {@code hftDoc} to {@code hft.xml},
	 * pretty-printed, through a UTF-8 writer. Failures are reported on stderr.
	 *
	 * @param hftDoc the document whose root element is written
	 */
	private static void writeHftXml(Document hftDoc) {
		try {
			// the writer's charset must agree with the encoding declared on the
			// LSOutput below; a FileWriter would use the platform default instead
			OutputStreamWriter hftOutput = new OutputStreamWriter(new FileOutputStream("hft.xml"),
					Charset.forName("UTF-8"));
			try {
				DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance();

				DOMImplementationLS impl =
				    (DOMImplementationLS)registry.getDOMImplementation("LS");
				LSSerializer writer = impl.createLSSerializer();
				writer.getDomConfig().setParameter("format-pretty-print", Boolean.TRUE);
				LSOutput output = impl.createLSOutput();
				output.setEncoding("UTF-8");
				output.setCharacterStream(hftOutput);
				writer.write(hftDoc.getDocumentElement(), output);
			} finally {
				hftOutput.close();
			}
		} catch(Exception e) {
			e.printStackTrace();
		}
	}

	/**
	 * Recursively enumerates the values of {@code factors[pos..]} and adds one
	 * child to {@code tree} for every complete combination.
	 *
	 * <p>Null and unknown dictionary values are skipped. If the tree ends up
	 * with exactly one child, that child is collapsed into {@code tree}: the
	 * tree takes over the child's payload and children.
	 *
	 * @param tree      node that receives the generated children
	 * @param tupleDesc description used to look up dictionaries and create tuples
	 * @param factors   indices of the factors to enumerate; must not be empty
	 * @param pos       position in {@code factors} handled by this call
	 * @param values    working array of factor values, indexed by factor index;
	 *                  modified in place
	 */
	private static void iterateThroughSpace(AnyaryTree<FactorTuple> tree, 
			FactorTupleDescription tupleDesc, 
			byte[] factors, 
			int pos, 
			int[] values) 
	{
		boolean isRoot = tree.getPayload() == null;
		FactorDescription desc = tupleDesc.getDescription(factors[pos]);
		edu.umd.clip.lm.factors.Dictionary dict;
		byte parentIdx = tupleDesc.getParentIndex(factors[pos]);
		if (parentIdx == -1) {
			dict = desc.getDictionary();
		} else {
			// a factor with a parent has one dictionary per parent value
			dict = desc.getDictionary(values[parentIdx]);
		}
		
		DictionaryIterator iter = dict.iterator(isRoot);
		if (iter.hasNext()) {
			// we don't want to remove payload if there are no children
			tree.setPayload(null);
		}
		
		while(iter.hasNext()) {
			int value = iter.next();
			if (Dictionary.isNull(value) || Dictionary.isUnk(value)) continue;
			
			values[factors[pos]] = value;
			if (pos == factors.length-1) {
				// last factor: the value combination is complete, emit a leaf
				tree.addChild(new AnyaryTree<FactorTuple>(new FactorTuple(tupleDesc.createTuple(values))));
			} else {
				iterateThroughSpace(tree, tupleDesc, factors, pos+1, values);
			}
		}
		HashSet<AnyaryTree<FactorTuple>> children = tree.getChildren();
		if (children != null && children.size() == 1) {
			// replace the sole child with grand-children
			AnyaryTree<FactorTuple> child = children.iterator().next(); 
			tree.setChildren(child.getChildren());
			tree.setPayload(child.getPayload());
		}
	}
	
	/**
	 * Converts a tree whose inner nodes have exactly two children into a
	 * {@link BinaryTree} of {@link HFTPayload}s, extending the prefix with zero
	 * for the first child and with one for the second.
	 *
	 * <p>Which child becomes "left" follows the iteration order of the child set.
	 *
	 * @param prefix prefix assigned to {@code node}
	 * @param node   node to convert; its payload must not be {@code null}
	 * @return the converted subtree
	 * @throws IllegalStateException if an inner node does not have exactly two children
	 */
	private static BinaryTree<HFTPayload> makeHFT(BinaryPrefix prefix, AnyaryTree<FactorTuple> node) {
		BinaryTree<HFTPayload> tree = new BinaryTree<HFTPayload>(new HFTPayload(prefix, node.getPayload().getBits()));

		HashSet<AnyaryTree<FactorTuple>> children = node.getChildren();
		if (children != null && !children.isEmpty()) {
			// an assert is skipped unless the JVM runs with -ea, which would let
			// extra children be dropped silently, so the invariant is checked explicitly
			if (children.size() != 2) {
				throw new IllegalStateException(
						"Expected a binary tree, but a node has " + children.size() + " children");
			}
			Iterator<AnyaryTree<FactorTuple>> iterator = children.iterator();
			AnyaryTree<FactorTuple> left = iterator.next();
			AnyaryTree<FactorTuple> right = iterator.next();
			
			tree.attachLeft(makeHFT(prefix.appendZero(), left));
			tree.attachRight(makeHFT(prefix.appendOne(), right));
		}
		return tree;
	}
}
